import numpy as np
import matplotlib.pyplot as plt
import os
import pdb
import random
import time
import types

# Default problem sizes, used by env_setup when no state/action spaces are given.
num_states = 5    # default number of states
num_actions = 2   # default number of actions available to each agent
num_agents = 6    # default number of agents

def set_hyps(a,a_default):
    """Return `a`, falling back to `a_default` only when `a` is None.

    Falsy values other than None (0, empty list, ...) are returned unchanged.
    """
    return a_default if a is None else a

def set_seed(seed=1):
    """Seed both NumPy's and Python's global random generators.

    Passing seed=None leaves both generators untouched.
    """
    if seed is None:
        return
    for seeder in (np.random.seed, random.seed):
        seeder(seed)

def env_setup(seed_init=1,state_space=None,action_spaces=None,rho=None,transP=None,reward=None,xi=[0,0],gamma=0.95):
    """Build the dictionary describing a constrained multi-agent MDP with K constraints.

    Parameters
    ----------
    seed_init : seed passed to set_seed before any random draw (None: no seeding).
    state_space : sequence of states; defaults to range(num_states).
    action_spaces : list with one action sequence per agent; defaults to
        num_agents copies of list(range(num_actions)).
    rho : array-like with num_states entries (initial state weights) or None for a
        random draw. Absolute values are taken and normalized to sum to one.
    transP : array-like of shape (num_states,num_actions1,...,num_actionsM,num_states)
        or None for a random draw. Absolute values are taken and normalized over
        the last axis.
    reward : array-like of shape
        (num_states,num_actions1,...,num_actionsM,num_states,num_agents,K+1)
        or None for a uniform random draw.
    xi : array-like with K entries, one threshold per constraint. It is never
        mutated here, so the list default is safe.
    gamma : discount factor.

    Returns
    -------
    dict with keys 'seed_init','state_space','action_spaces','num_states',
    'num_actions' (list, one entry per agent),'num_agents','rho','transP','gamma',
    'P_rho' (= gamma*transP+(1-gamma)*rho over the next-state axis),'K','xi',
    'reward' and 'reward_agentavg' (reward averaged over the agent axis).

    Raises
    ------
    AssertionError if rho, transP or reward do not have the expected size/shape.
    """
    env_dict={}
    set_seed(seed_init)
    env_dict['seed_init']=seed_init

    env_dict['state_space']=set_hyps(a=state_space,a_default=range(num_states))
    env_dict['action_spaces']=set_hyps(a=action_spaces,a_default=[list(range(num_actions))]*num_agents)
    env_dict['num_states']=len(env_dict['state_space'])
    env_dict['num_actions']=[len(tmp) for tmp in env_dict['action_spaces']]    #env_dict['num_actions'][m] for agent m
    env_dict['num_agents']=len(env_dict['num_actions'])
    #Local aliases (deliberately not named like the module-level defaults).
    n_s=env_dict['num_states']
    n_a=env_dict['num_actions']
    n_m=env_dict['num_agents']

    if rho is None:
        rho=np.random.normal(size=(n_s))
    else:
        rho=np.asarray(rho)   #accept lists/tuples as well as arrays
        assert rho.size==n_s, "rho should have "+str(n_s)+" entries."
    rho=np.abs(rho).reshape(n_s)
    env_dict['rho']=rho/rho.sum()

    transP_shape=tuple([n_s]+n_a+[n_s])
    if transP is None:
        transP=np.random.normal(size=transP_shape)   #P(s,a1,...,aM,s')
    else:
        transP=np.asarray(transP)
        assert transP.shape == transP_shape, \
            "transP should have shape: (num_states,num_actions1,...,num_actionsM,num_states)"
    transP=np.abs(transP)
    #Normalize over the next-state axis (the last one) so each P(.|s,a) sums to one.
    env_dict['transP']=transP/np.sum(transP,axis=-1,keepdims=True)

    env_dict['gamma']=gamma
    #rho has shape (num_states,), so it broadcasts along the next-state (last) axis.
    env_dict['P_rho']=gamma*env_dict['transP']+(1-gamma)*env_dict['rho']

    xi=np.asarray(xi).reshape(-1)   #always a flat vector, whatever array-like was passed
    K=xi.shape[0]
    env_dict['K']=K
    env_dict['xi']=xi.copy()

    reward_shape=tuple([n_s]+n_a+[n_s]+[n_m]+[K+1])
    if reward is None:
        #Draw the random reward only when it is actually needed.
        reward=np.random.uniform(size=reward_shape)
    else:
        reward=np.asarray(reward)
        assert reward.shape==reward_shape,\
            "reward should be either None or an np.array with shape (num_states,num_actions1,...,num_actionsM,num_states,num_agents,K+1)"
    env_dict['reward']=reward
    #Axis num_agents+2 is the agent axis: (s, a1..aM, s', agent, k).
    env_dict['reward_agentavg']=env_dict['reward'].mean(axis=n_m+2)

    return env_dict

def get_transP_s2s(pim,transP): #From P(s'|s,a), obtain P(s'|s)
    """Marginalize the joint actions out of a transition kernel.

    pim is a list with one array per agent; pim[m] is reshaped to
    (num_states, num_actions_m) and multiplied onto action axis m+1 of transP,
    whose axes are (s, a1, ..., aM, s'). The product is then summed over all
    action axes, giving an array indexed by (s, s'). transP itself is not modified.
    """
    n_agents=len(pim)
    n_states=transP.shape[0]
    weighted=transP.copy()   #in-place products below must not touch the caller's array
    for idx,policy in enumerate(pim):
        #Broadcast shape: states on axis 0, this agent's actions on axis idx+1, 1 elsewhere.
        bshape=[1]*(n_agents+2)
        bshape[0]=n_states
        bshape[idx+1]=policy.shape[1]
        weighted*=policy.reshape(bshape)
    action_axes=tuple(range(1,n_agents+1))
    return weighted.sum(axis=action_axes)

def stationary_dist(transP_s2s):  #Stationary distribution corresponding to transP_s2s
    """Return the stationary distribution of the state-to-state matrix transP_s2s.

    Takes the first eigenvector of transP_s2s.T whose eigenvalue is close to 1,
    uses its entrywise magnitudes and normalizes them to sum to one.
    """
    eigvals,eigvecs=np.linalg.eig(transP_s2s.T)   #P.T @ eigvecs = eigvecs @ diag(eigvals)
    unit_mask=np.isclose(eigvals,1)
    leading=np.abs(eigvecs[:,unit_mask][:,0])
    normalized=leading/leading.sum()
    return normalized.real

def Vk_rho_func(pim,transP,P_rho,reward_agentavg_k,gamma,nu_pi=None):
    """Return the scalar value of the k-th reward/safety score under policy pim.

    The expected one-step reward sum_{s'} transP*reward_agentavg_k is averaged
    over the agents' actions with pim (list of per-agent arrays, each reshaped
    to (num_states, num_actions_m)), weighted by the state distribution nu_pi
    and divided by (1-gamma). When nu_pi is None it is computed as the
    stationary distribution of the state-to-state kernel induced by P_rho and
    pim; P_rho is not used otherwise.
    """
    n_agents=len(pim)
    n_states=transP.shape[0]

    if nu_pi is None:
        nu_pi=stationary_dist(get_transP_s2s(pim,P_rho))

    #Expected reward for every (s,a1,...,aM): sum over the next state.
    expected_r=(reward_agentavg_k*transP).sum(axis=n_agents+1)
    #Integrate the agents out one by one, starting from the last action axis.
    for idx in reversed(range(n_agents)):
        policy=pim[idx]
        bshape=(n_states,)+(1,)*idx+(policy.shape[1],)
        expected_r*=policy.reshape(bshape)
        expected_r=expected_r.sum(axis=idx+1)
    return (expected_r*nu_pi).sum()/(1-gamma)

def Vk_s_func(pim,transP,reward_agentavg_k,gamma):
    """Return the vector Vk_s with one value per starting state.

    For each state s the restart kernel
        P_rho(s'|.) = gamma*transP(s'|.) + (1-gamma)*1{s'=s}
    is built and passed to Vk_rho_func, so Vk_s[s] is the value obtained when
    all initial mass is put on s.

    pim : list of per-agent policy arrays, as expected by Vk_rho_func.
    transP : transition kernel whose first axis is the current state and whose
        last axis is the next state.
    reward_agentavg_k : reward array passed through to Vk_rho_func.
    gamma : discount factor.
    """
    num_states=transP.shape[0]
    Vk_s=np.zeros(num_states)
    discounted=gamma*transP   #loop-invariant part of every restart kernel
    for s in range(num_states):
        P_rho=discounted.copy()
        #Add the restart mass on next-state s; `...` spans state and all action axes.
        P_rho[...,s]+=1-gamma
        Vk_s[s]=Vk_rho_func(pim,transP,P_rho,reward_agentavg_k,gamma,nu_pi=None)
    return Vk_s

def Qkm_func(pim,transP,reward_agentavg_k,gamma,Vk_s=None):  #return matrix Q_k^{(m)}(s,a^{(m)}) for agent m and k-th reward/safety score
    """Return a list with one array Q_k^{(m)}(s,a^{(m)}) per agent m.

    The joint table
        Q(s,a)=sum_{s'} transP(s'|s,a)*[reward_agentavg_k(s,a,s')+gamma*Vk_s(s')]
    is averaged over the actions of every agent except m using their policies
    in pim. Vk_s is computed with Vk_s_func when it is not supplied.
    """
    if Vk_s is None:
        Vk_s=Vk_s_func(pim,transP,reward_agentavg_k,gamma)
    n_states=transP.shape[0]
    n_agents=len(pim)

    next_value=gamma*Vk_s.reshape((1,)*(n_agents+1)+(n_states,))
    #q_joint has axes (s,a_m,...,a_M) once agents before m have been averaged out.
    q_joint=(transP*(reward_agentavg_k+next_value)).sum(axis=n_agents+1)

    per_agent=[]
    for m in range(n_agents):
        if m>0:
            #Permanently average out agent m-1, which now sits on axis 1.
            prev_policy=pim[m-1]
            lead_shape=(n_states,prev_policy.shape[1])+(1,)*(n_agents-m)
            q_joint=(prev_policy.reshape(lead_shape)*q_joint).sum(axis=1)
        q_m=q_joint
        #Average out the later agents; each one sits on axis 2 when its turn comes.
        for later in range(m+1,n_agents):
            later_policy=pim[later]
            tail_shape=(n_states,1,later_policy.shape[1])+(1,)*(n_agents-later-1)
            q_m=(q_m*later_policy.reshape(tail_shape)).sum(axis=2)
        per_agent.append(q_m.copy())
    return per_agent

def proj_Pr(y):  #Project vector y onto the probability simplex
    """Return the Euclidean projection of y onto {x: x>=0, sum(x)=1}.

    y : array-like; it is flattened first, so row/column vectors are accepted.
        The input is not modified.
    Returns a new 1-D float array with non-negative entries summing to one.
    Raises ValueError if y is empty or has no usable (finite) entries.
    """
    y=np.asarray(y,dtype=float).reshape(-1)   #flatten BEFORE sorting
    if y.size==0:
        raise ValueError("proj_Pr needs a non-empty vector.")
    u=np.sort(y)[::-1]   #entries in decreasing order
    #shifts[j]: uniform shift that makes the j+1 largest entries sum to one.
    shifts=(1-np.cumsum(u))/np.arange(1,y.size+1)
    #The right shift belongs to the last j whose shifted entry is still positive.
    feasible=np.nonzero(shifts+u>0)[0]
    if feasible.size==0:
        raise ValueError("proj_Pr needs finite entries.")
    x=y+shifts[feasible[-1]]
    x[x<0]=0
    return x

def PrimalDual_population(env_dict,unconstrained_alg="exact",T=100,T_Viter=50,alpha_pi=0.1,beta_lambda=0.1,lambda_kmax=None,pi0=None,is_print=False,is_save=False,save_folder="results/PrimalDual_population"):
    """Primal-dual iterations for the constrained multi-agent MDP in env_dict.

    At every outer iteration t the surrogate reward
        R_hat = reward_0 + sum_k lambda_k[t] * reward_k      (agent-averaged)
    is formed, the policy is updated for R_hat, and each multiplier is moved by
        lambda_k[t+1] = clip(lambda_k[t] - beta_lambda*(V_k - xi_k), 0, lambda_kmax_k),
    where V_k is the value of the k-th safety score under the new policy.

    Parameters
    ----------
    env_dict : dictionary as returned by env_setup (a shallow copy is used).
    unconstrained_alg : "exact" runs T_Viter value-iteration sweeps on R_hat and
        takes the greedy deterministic joint policy; any other value performs one
        projected policy-gradient step of size alpha_pi.
    T : number of outer iterations.
    T_Viter : value-iteration sweeps per outer iteration ("exact" only).
    alpha_pi : policy step size (policy-gradient branch only).
    beta_lambda : step size of the multiplier update.
    lambda_kmax : upper bounds of the multipliers (default: 10 for each of the K).
    pi0 : initial policy, list of (num_states,num_actions[m]) arrays
        (default: uniform). It is copied, never modified.
    is_print : print V_0,...,V_K at every iteration.
    is_save : save results under save_folder.
    save_folder : output directory, created if missing.

    Returns
    -------
    dict with 'pim' (list over t=0..T of per-agent policy lists), the
    hyperparameters, 'lambda' of shape (T+1,K), 'Vk_rho' of shape (K+1,T+1)
    and 'time(s)'.
    """
    start_time=time.time()
    env_dict=env_dict.copy()
    K=env_dict['K']
    n_states=env_dict['num_states']
    n_actions=env_dict['num_actions']
    n_agents=env_dict['num_agents']
    gamma=env_dict['gamma']
    transP=env_dict['transP']
    P_rho=env_dict['P_rho']
    reward_avg=env_dict['reward_agentavg']
    action_axes=tuple(range(1,n_agents+1))

    if lambda_kmax is None:
        lambda_kmax=np.ones(K)*10.0
    if (pi0 is None):
        pi0=[np.ones((n_states,n_actions[m]))/n_actions[m] for m in range(n_agents)]

    def evaluate(policy):
        #Values V_0,...,V_K of `policy`, one per reward/safety score.
        return np.array([Vk_rho_func(policy,transP,P_rho,reward_avg[...,k],gamma,nu_pi=None)
                         for k in range(K+1)])

    def q_joint(R_hat,V):
        #Q(s,a1,...,aM)=sum_{s'} transP*(R_hat+gamma*V(s')); V broadcasts on the last axis.
        return ((R_hat+gamma*V)*transP).sum(axis=-1)

    def report(t):
        #Print the values stored for iteration t.
        values=PD_dict['Vk_rho'][:,t]
        print("Iteration "+str(t)+": "+", ".join("V_"+str(k)+"="+str(values[k]) for k in range(K+1))+".")

    PD_dict={'pim':[-1]*(T+1)}
    #Independent float copies: the caller's pi0 must never be altered.
    PD_dict['pim'][0]=[np.array(p,dtype=float) for p in pi0]

    hyps={'unconstrained_alg':unconstrained_alg,'T':T,'T_Viter':T_Viter,'alpha_pi':alpha_pi,
          'beta_lambda':beta_lambda,'lambda_kmax':lambda_kmax,'is_print':is_print,'is_save':is_save}
    keys=list(hyps)
    PD_dict.update(hyps)

    PD_dict['lambda']=np.zeros((T+1,K))
    PD_dict['Vk_rho']=np.zeros((K+1,T+1))
    PD_dict['Vk_rho'][:,0]=evaluate(PD_dict['pim'][0])

    for t in range(T):      #t-th outer primal-dual iteration
        if is_print:
            report(t)

        lam=PD_dict['lambda'][t]
        #R_hat[s,a1,...,aM,s'] is the agents' average surrogate reward for unconstrained MARL
        R_hat=(reward_avg[...,1:]*lam).sum(axis=-1)+reward_avg[...,0]

        if unconstrained_alg=="exact":  #Use value iteration
            V=np.zeros(n_states)
            for _ in range(T_Viter):
                V=q_joint(R_hat,V).max(axis=action_axes)
            #Greedy joint action per state, split into one action index per agent.
            flat_best=q_joint(R_hat,V).reshape(n_states,-1).argmax(axis=1)
            a_opt=np.column_stack(np.unravel_index(flat_best,tuple(n_actions)))
            new_pim=[]
            for m in range(n_agents):
                pim_m=np.zeros((n_states,n_actions[m]))
                pim_m[np.arange(n_states),a_opt[:,m]]=1
                new_pim.append(pim_m)
        else:    #policy gradient
            #Copy each array: the in-place step below must not rewrite iteration t.
            new_pim=[p.copy() for p in PD_dict['pim'][t]]
            nu_pi=stationary_dist(get_transP_s2s(new_pim,P_rho))
            Qkm=Qkm_func(new_pim,transP,R_hat,gamma,Vk_s=None)
            step=((alpha_pi/(1-gamma))*nu_pi).reshape((-1,1))
            for m in range(n_agents):
                new_pim[m]+=step*Qkm[m]
                for s in range(n_states):
                    new_pim[m][s]=proj_Pr(new_pim[m][s])

        PD_dict['pim'][t+1]=new_pim
        PD_dict['Vk_rho'][:,t+1]=evaluate(new_pim)

        #Each multiplier uses the value of ITS OWN constraint (rows 1..K).
        lambda_tmp=lam-beta_lambda*(PD_dict['Vk_rho'][1:,t+1]-env_dict['xi'])
        lambda_tmp=np.maximum(lambda_tmp,0)
        lambda_tmp=np.minimum(lambda_tmp,lambda_kmax)
        PD_dict['lambda'][t+1]=lambda_tmp

    if is_print:
        report(T)   #values of the final policy (also valid when T=0)
    PD_dict['time(s)']=time.time()-start_time

    if is_save:
        os.makedirs(save_folder,exist_ok=True)

        np.save(file=save_folder+'/Vk_rho.npy',arr=PD_dict['Vk_rho'])
        np.save(file=save_folder+'/lambda.npy',arr=PD_dict['lambda'])

        for m in range(n_agents):
            #Trajectory of agent m's policy, shape (T+1,num_states,num_actions[m]).
            trajectory=np.array([PD_dict['pim'][t][m] for t in range(T+1)])
            np.save(file=save_folder+'/pim_agent'+str(m)+'.npy',arr=trajectory)

        #Append mode: repeated runs accumulate in the same file.
        with open(save_folder+'/hyperparameters.txt','a') as hyp_txt:
            for key in keys:
                hyp_txt.write(key+'='+str(PD_dict[key])+'\n')
            hyp_txt.write('Time consumption: '+str(PD_dict['time(s)']/60)+' minutes\n')
    return PD_dict
    
def primal_population(env_dict,T=100,alpha=0.1,tol=0.1,n_between_evals=5,pi0=None,is_print=False,is_save=False,save_folder="results/primal_population"):
    """Primal-only iterations for the constrained multi-agent MDP in env_dict.

    At iteration t the first constraint k (k=1..K) with V_k < xi[k-1]-tol is
    selected as kt (kt=0, the main reward, when none is violated). Every agent
    then updates its policy multiplicatively,
        pi_m(a|s) <- pi_m(a|s)*exp(alpha*Q_kt^{(m)}(s,a)), renormalized over a.

    Parameters
    ----------
    env_dict : dictionary as returned by env_setup (a shallow copy is used).
    T : number of iterations.
    alpha : step size inside the exponential.
    tol : tolerance subtracted from each threshold xi[k-1].
    n_between_evals : all K+1 values are recorded every n_between_evals
        iterations; a value <=0 disables recording ('iters_eval' is then None
        and 'Vk_rho' has zero columns).
    pi0 : initial policy, list of (num_states,num_actions[m]) arrays
        (default: uniform). It is copied, never modified.
    is_print : print progress at every iteration.
    is_save : save results under save_folder.
    save_folder : output directory, created if missing.

    Returns
    -------
    dict with the hyperparameters, 'pim' (list over t=0..T of per-agent policy
    lists), 'kt' (selected index per iteration), 'iters_eval', 'Vk_rho' of
    shape (K+1,number of evaluations) and 'time(s)'.
    """
    start_time=time.time()
    env_dict=env_dict.copy()
    primal_dict={'alpha':alpha,'tol':tol,'T':T,'n_between_evals':n_between_evals,'is_print':is_print,'is_save':is_save}
    K=env_dict['K']
    n_states=env_dict['num_states']
    n_actions=env_dict['num_actions']
    n_agents=env_dict['num_agents']
    gamma=env_dict['gamma']
    transP=env_dict['transP']
    P_rho=env_dict['P_rho']
    reward_avg=env_dict['reward_agentavg']

    if (pi0 is None):
        pi0=[np.ones((n_states,n_actions[m]))/n_actions[m] for m in range(n_agents)]
    primal_dict['pim']=[-1]*(T+1)
    primal_dict['pim'][0]=[np.array(p,dtype=float) for p in pi0]

    primal_dict['kt']=[0]*T
    do_eval=n_between_evals>0
    #The iteration indexes to save Vk_rho values (None when evaluation is disabled)
    primal_dict['iters_eval']=list(range(0,T,n_between_evals)) if do_eval else None
    n_evals=len(primal_dict['iters_eval']) if do_eval else 0
    primal_dict['Vk_rho']=np.zeros((K+1,n_evals))

    def value(policy,k):
        #Value of the k-th reward/safety score under `policy`.
        return Vk_rho_func(policy,transP,P_rho,reward_avg[...,k],gamma,nu_pi=None)

    for t in range(T):
        policy=primal_dict['pim'][t]
        is_eval=do_eval and t%n_between_evals==0
        t_eval=t//n_between_evals if do_eval else None   #column of 'Vk_rho' for this iteration

        kt=0
        for k in range(1,K+1):
            Vk_rho=value(policy,k)
            if is_eval:
                primal_dict['Vk_rho'][k,t_eval]=Vk_rho
            if kt==0 and Vk_rho<env_dict['xi'][k-1]-tol:
                kt=k
                primal_dict['kt'][t]=kt
                if not is_eval:
                    break   #remaining constraints only matter when recording values
        if is_eval:
            primal_dict['Vk_rho'][0,t_eval]=value(policy,0)

        if is_print:
            if is_eval:
                values=primal_dict['Vk_rho'][:,t_eval]
                tmp=": "+", ".join("V_"+str(k)+"="+str(values[k]) for k in range(K+1))+", kt="+str(kt)+"."
                print("Iteration "+str(t)+tmp)
            else:
                print("Iteration "+str(t)+": kt="+str(kt)+".")

        Qkm=Qkm_func(policy,transP,reward_avg[...,kt],gamma,Vk_s=None)
        new_pim=[]
        for m in range(n_agents):
            tmp=policy[m]*np.exp(alpha*Qkm[m])
            new_pim.append(tmp/(tmp.sum(axis=1).reshape((-1,1))))
        primal_dict['pim'][t+1]=new_pim

    primal_dict['time(s)']=time.time()-start_time

    if is_save:
        os.makedirs(save_folder,exist_ok=True)

        np.save(file=save_folder+'/Vk_rho.npy',arr=primal_dict['Vk_rho'])
        np.save(file=save_folder+'/kt.npy',arr=primal_dict['kt'])

        for m in range(n_agents):
            #Trajectory of agent m's policy, shape (T+1,num_states,num_actions[m]).
            trajectory=np.array([primal_dict['pim'][t][m] for t in range(T+1)])
            np.save(file=save_folder+'/pim_agent'+str(m)+'.npy',arr=trajectory)

        keys=['T','alpha','tol']
        keys+=['n_between_evals','is_print','is_save']
        #Append mode: repeated runs accumulate in the same file.
        with open(save_folder+'/hyperparameters.txt','a') as hyp_txt:
            for key in keys:
                hyp_txt.write(key+'='+str(primal_dict[key])+'\n')
            hyp_txt.write('Time consumption: '+str(primal_dict['time(s)']/60)+' minutes\n')

    return primal_dict
    

